import tensorflow as tf
import torch
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
import mysql.connector
import re

# Name or path of the GPT-2 checkpoint handed to from_pretrained(); it is an
# empty placeholder here and must be filled in before the script can run.
gptModel = ""

tokenizer = GPT2Tokenizer.from_pretrained(gptModel)

model = TFGPT2LMHeadModel.from_pretrained(
    gptModel,
    from_pt=True,  # load the checkpoint from PyTorch weights into the TF model
    # Reuse the EOS token as the PAD token to avoid warnings during generate().
    pad_token_id=tokenizer.eos_token_id,
)


def dbConnection():
    """Open a MySQL connection and a cursor on it, with autocommit enabled.

    The connection parameters are blank placeholders in this file and must be
    filled in before running.

    Returns:
        A ``(connection, cursor)`` tuple; release both with ``finish()``.
    """
    connectionArgs = dict(
        host="",
        user="",
        passwd="",
        database="",
    )
    connection = mysql.connector.connect(**connectionArgs)

    cursor = connection.cursor()
    # Every statement commits immediately, so callers never call commit().
    connection.autocommit = True

    return connection, cursor


def finish(mydb, mycursor):
    """Close the cursor first, then the connection it belongs to."""
    for handle in (mycursor, mydb):
        handle.close()

def getData():
    """Load the 'news' seed rows with label '5.5' from ``gptSeeds``.

    Returns:
        A list of ``[textType, textId, text, label]`` lists built from columns
        0-3 of each row. ``text`` is decoded to ``str`` if needed and passed
        through ``clean()``.

    NOTE(review): the query is ``SELECT *``, so this assumes the first four
    columns of gptSeeds are (type, id, text, label). Confirm against the schema.
    """
    sql = "SELECT * from gptSeeds where dataset='news' and label = '5.5';"

    mydb, mycursor = dbConnection()
    try:
        mycursor.execute(sql)
        res = mycursor.fetchall()
    finally:
        # Release the connection even when the query fails.
        finish(mydb, mycursor)

    allData = []
    for row in res:
        rawText = row[2]
        # The connector may return the text column as bytes/bytearray or as
        # str, so only decode when it is actually binary.
        if isinstance(rawText, (bytes, bytearray)):
            rawText = rawText.decode("utf-8")
        allData.append([row[0], row[1], clean(rawText), row[3]])

    return allData

def clean(txt):
    """Normalize a seed text into a single line of space-separated words.

    Steps:
      * ``-`` at a line break (``-\\n`` / ``\\n-``) and ``--`` become a space.
      * Runs of ``*`` become a space.
      * Doubled single quotes (``''``) become one ``'``.
      * All whitespace (newlines, tabs, carriage returns, repeated spaces) is
        collapsed to single spaces, and leading/trailing whitespace is dropped.

    Line breaks are treated as word separators rather than deleted, so words
    on adjacent lines are not glued together.
    """
    # Normalize Windows/old-Mac line endings so the patterns below see "\n".
    txt = txt.replace("\r\n", "\n").replace("\r", "\n")
    # This must run while newlines are still present, otherwise the
    # hyphen/line-break alternatives can never match.
    txt = re.sub(r"-\n|--|\n-", " ", txt)
    txt = re.sub(r"\*+", " ", txt)
    txt = txt.replace("''", "'")
    # str.split() with no argument splits on any whitespace run and discards
    # empty pieces, which both collapses and strips the text.
    return " ".join(txt.split())

def sampling(input_ids, sent, maxl, num_return_sequences=2):
    """Sample continuations of a prompt with top-k / nucleus sampling.

    Args:
        input_ids: prompt token ids as produced by
            ``tokenizer.encode(..., return_tensors='tf')``; presumably shape
            (1, promptLen). Confirm if calling with a batch.
        sent: the prompt text. The prompt is removed from each output by token
            position rather than by text matching, so this is no longer used.
            It is kept so existing callers still work.
        maxl: ``max_length`` for ``model.generate`` (prompt plus continuation).
        num_return_sequences: how many continuations to sample (default 2).

    Returns:
        A list with one string per sampled sequence. Each string contains only
        the generated continuation, with whitespace collapsed to single spaces.
    """
    # Seeding is left disabled so each run produces different text.
    # tf.random.set_seed(42)
    promptLen = input_ids.shape[1]
    outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=maxl,
        top_k=50,
        top_p=0.90,
        num_return_sequences=num_return_sequences,
    )

    allTxt = []
    for sequence in outputs:
        # generate() returns prompt + continuation. Slicing off the prompt
        # tokens is exact, whereas matching the decoded prompt text can fail
        # when decode() rewrites spacing around punctuation.
        continuation = sequence[promptLen:]
        txt = tokenizer.decode(continuation, skip_special_tokens=True)
        allTxt.append(" ".join(txt.split()))

    return allTxt

def safeRes(result, experiment, dataId, label):
    """Store generated texts as rows of ``humaninloopGeneratedOne``.

    Args:
        result: iterable of generated strings.
        experiment: value for the ``experiment`` column.
        dataId: id of the seed row. It is stored in ``comeFrom``, and the
            generated rows get ids ``"<dataId>.1"``, ``"<dataId>.2"``, ...
        label: value for the ``label`` column.

    The ``gptModel`` column is always set to ``'perlabel'``.
    """
    # Values are bound by the driver (%s placeholders). Quotes and backslashes
    # in generated text are therefore stored verbatim and cannot break or
    # inject into the statement; no manual stripping or escaping is needed.
    sql = ("INSERT INTO humaninloopGeneratedOne"
           "(experiment, comeFrom, dataId, data, label, gptModel) "
           "VALUES (%s, %s, %s, %s, %s, %s)")

    mydb, mycursor = dbConnection()
    try:
        for count, text in enumerate(result, start=1):
            text = text.replace("\t", " ").replace("\n", " ")
            genId = str(dataId) + "." + str(count)
            mycursor.execute(
                sql, (experiment, dataId, genId, text, label, "perlabel"))
    finally:
        # Always release the connection, even if an insert fails.
        finish(mydb, mycursor)


def shortenNews(clData):
    """Return at most the first 1000 characters of ``clData``."""
    limit = 1000
    return clData[:limit] if len(clData) > limit else clData

def splitHalves(shortData):
    """Return the leading part of ``shortData``.

    Despite the name, only one piece is returned:

    * If ``". "`` occurs, the result is everything before its first occurrence.
      This can be an empty string when the text starts with ``". "``.
    * Otherwise, the result is the first ``len(shortData) // 2`` characters.
    """
    head, separator, _rest = shortData.partition(". ")
    if separator:
        return head
    return shortData[:len(shortData) // 2]

# Prompts with at least this many tokens are shortened before generation.
_MAX_PROMPT_TOKENS = 1000
# max_length handed to sampling(); it covers the prompt plus the continuation.
_MAX_LENGTH = 1024


def _buildPrompt(clData):
    """Return ``(input_ids, promptText)`` for ``clData``, shortened if too long.

    Candidates are tried in order: the full text, then its first 1000
    characters (``shortenNews``), then the first sentence or first half of
    those characters (``splitHalves``). Each candidate is tokenized itself,
    so the length check applies to the text that is actually used.
    """
    input_ids = tokenizer.encode(clData, return_tensors='tf')
    if input_ids.shape[1] < _MAX_PROMPT_TOKENS:
        return input_ids, clData

    shortData = shortenNews(clData)
    # Encode the shortened text itself, so both the check and the ids passed
    # to sampling() reflect the shortening.
    input_idsShort = tokenizer.encode(shortData, return_tensors='tf')
    if input_idsShort.shape[1] < _MAX_PROMPT_TOKENS:
        return input_idsShort, shortData

    shortDataA = splitHalves(shortData)
    return tokenizer.encode(shortDataA, return_tensors='tf'), shortDataA


def main(experiment=""):
    """Generate continuations for every seed row and store them.

    Args:
        experiment: value written to the ``experiment`` column by ``safeRes``.
            It is an empty placeholder, like the other configuration values in
            this file; set it before running.
    """
    for _textType, dataId, clData, label in getData():
        # getData() has already passed the text through clean().
        input_ids, prompt = _buildPrompt(clData)
        result = sampling(input_ids, prompt, _MAX_LENGTH)
        safeRes(result, experiment, dataId, label)



# Run only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()